import torch.optim as optim
import torch
from model import *
from dataset import *
import torch.nn as nn

import torch.optim as optim

# Prefer the first CUDA GPU when PyTorch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# GazeNet is not defined in this file; presumably it is provided by the
# star import from `model` -- confirm there.
model = GazeNet()
model = model.to(device)

# Loss and optimizer handed to train_model() further down.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch makes one pass over ``train_loader``, then scores the model on
    ``val_loader`` with ``evaluate_model``. Whenever the validation accuracy
    beats the best value seen so far, ``model.state_dict()`` is saved to
    ``best_model.pth`` (relative to the current working directory).

    Args:
        model: Network to optimize. Batches are moved to the module-level
            ``device`` before being passed to it.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Loader passed to ``evaluate_model`` after every epoch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        num_epochs: Number of passes over ``train_loader``.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Starting at 0.0 left no file behind when the validation
    # accuracy never rose above zero.
    best_val_acc = float('-inf')
    best_model_path = 'best_model.pth'

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch before the optimization pass.
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Batches are counted by hand so that loaders without __len__ work and
        # an empty loader reports a loss of 0.0 instead of dividing by zero.
        avg_loss = running_loss / num_batches if num_batches else 0.0

        val_acc = evaluate_model(model, val_loader)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Val Acc: {val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved with accuracy: {best_val_acc:.4f}')

def evaluate_model(model, data_loader):
    """Return the top-1 classification accuracy of ``model`` on ``data_loader``.

    The model is switched to evaluation mode and is left in that mode on
    return; a caller that keeps training must call ``model.train()`` again.

    Args:
        model: Module invoked as ``model(images)``; its output is reduced
            with a max over dimension 1 to obtain the predicted class index.
        data_loader: Iterable yielding ``(images, labels)`` batches. Batches
            are moved to the module-level ``device`` before the forward pass.

    Returns:
        Accuracy as a percentage (``100 * correct / total``), or ``0.0`` when
        the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for scoring.
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the largest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

# Train the model.
# NOTE(review): train_loader / val_loader / test_loader are not defined in this
# file; presumably they come from the star import of `dataset` -- confirm.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=200,
)

# Test the model: reload the checkpoint saved during training and score it
# on the test split.
model.load_state_dict(torch.load('best_model.pth'))
test_acc = evaluate_model(model=model, data_loader=test_loader)
print(f'Accuracy of the best model on the test images: {test_acc:.4f}%')

